题解 CF519E A and B and Lecture Rooms

题意:有棵大小为$n$的树,再给出m次询问,每次询问中包含$A,B$两点,我们要找到离$A,B$两点距离相等的点一共有多少个。

需要对$A,B$之间的距离进行分类讨论:

一.如果询问的两个点之间的距离为奇数(或者之间的点为偶数),那么无论怎样,它们之间必然有偶数个点,不可能有点到它们的距离相等。

二.如果询问的两个点之间的距离为偶数时,我们要找到$A,B$之间的中点,这个时候又需要分几个情况

{

$1.A,B$两点到他们的$LCA$的距离不相等(包括$A,B$两点中其中一个点为另一个点的$LCA$的情况),那么我们需要找到$A,B$两点所在链上的中点,中点与它不包含所询问点的子树上的点都是满足条件的点

$2.A,B$两点到他们的$LCA$之间距离相等时,满足条件的点的个数即是整棵树的上节点的总数减去$LCA$包含所询问两点的子树的节点个数

}

三.$A,B$两点重合时,整颗树上的点到这两个点的距离都可以看做相等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <bits/stdc++.h>
#define ll long long
#define sqr(x) ((x)*(x))
struct node{
int to,next;
}e[300000];
using namespace std;
inline void write(int x) {if (x<0) putchar('-'),x=-x;if (x>=10) write(x/10);putchar(x%10|'0');}
inline void wln(int x) {write(x);puts("");}
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
int tot,size[300000],head[300000],cnt,fa[300000],dep[300000],son[300000],top[300000],w[300000],p[3000000];
void add(int u,int v){
e[++cnt].to=v;
e[cnt].next=head[u];
head[u]=cnt;
}
void dfs1(int u){
size[u]=1;
for (int i=head[u];i;i=e[i].next){
int v=e[i].to;
if (fa[u]==v)continue;
fa[v]=u;dep[v]=dep[u]+1;
dfs1(v);
size[u]+=size[v];
if (!son[u]||size[v]>size[son[u]])
son[u]=v;
}
}
void dfs2(int u,int tp){
top[u]=tp;w[u]=++tot;p[tot]=u;
if (son[u])dfs2(son[u],tp);
for (int i=head[u];i;i=e[i].next){
int v=e[i].to;
if (v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
inline int LCA(int x,int y){
while (top[x]!=top[y]){
if (dep[top[x]]>dep[top[y]])
x=fa[top[x]];
else y=fa[top[y]];
}
return dep[x]>dep[y]?y:x;
}//以上是树剖求LCA板子
int find_mid(int x,int y,int lca,int len,int far){
if (dep[x]-dep[lca]<len){
swap(x,y);
len=far-len;
}
while (x!=lca&&len>0){
int net=top[x];
if (w[x]-w[net]+1<=len)len-=w[x]-w[net]+1;
else return p[w[x]-len];
x=fa[net];
}
return x;
}//找mid这里应该用倍增,但是写了一种自己都看不懂的算法qwq
int main(){
int n=read();int m=n-1;
for (int i=1;i<=m;++i){
int u=read(),v=read();
add(u,v);add(v,u);
}
dfs1(1);
dfs2(1,1);
int q=read();
while (q--){
int x=read(),y=read();
int lca=LCA(x,y);
int far=(dep[x]+dep[y]-dep[lca]*2);
int mid=find_mid(x,y,lca,far/2,far);
int midl=find_mid(x,y,lca,far/2-1,far);
int midr=find_mid(y,x,lca,far/2-1,far);
if (far&1)
printf("0\n");
else if (x==y){
printf("%d\n",n);
}else if (dep[midl]==dep[midr]){
printf("%d\n",n-size[midl]-size[midr]);
}else{
printf("%d\n",size[mid]-(dep[midl]>dep[midr]?size[midl]:size[midr]));
}
}
return 0;
}